#!/usr/bin/env python

# --------------------------------------------------------
# Tensorflow Faster R-CNN
# Licensed under The MIT License [see LICENSE for details]
# Written by Xinlei Chen, based on code from Ross Girshick
# --------------------------------------------------------

"""
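Demo script that runs grasp detection on a list of test images, compares the
top-scoring prediction with ground-truth grasp rectangles, prints running
recall/precision, and saves a visualization of the detections for each image.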

See README.md for installation instructions before running.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import _init_paths
from model.config import cfg
from model.test import im_detect
from model.nms_wrapper import nms

from utils.timer import Timer
import tensorflow as tf
import matplotlib.pyplot as plt
import numpy as np
import os, cv2
import argparse

from nets.vgg16 import vgg16
from nets.resnet_v1 import resnetv1
import scipy
from shapely.geometry import Polygon
from math import acos


# Short aliases for the numeric helpers used by the geometry code below.
# They are bound to NumPy directly: the same names re-exported from the root
# ``scipy`` namespace are deprecated (since SciPy 1.4) and scheduled for removal.
pi = np.pi
dot = np.dot
sin = np.sin
cos = np.cos
ar = np.array

# Running tallies accumulated by demo() over every evaluated image.
#   total_found                 - images where a prediction matched ground truth
#   total_count                 - images evaluated so far
#   total_predicted             - detections that passed the score threshold
#   total_predicted_and_matched - of those, detections that matched ground truth
total_found = total_count = 0
total_predicted = total_predicted_and_matched = 0

# Class labels: background followed by the 19 orientation classes
# 'angle_01' ... 'angle_19'.
CLASSES = ('__background__',) + tuple(
    'angle_{:02d}'.format(k) for k in range(1, 20))

# Checkpoint file name for each supported backbone.
NETS = {
    'vgg16': ('vgg16_faster_rcnn_iter_70000.ckpt',),
    'res101': ('res101_faster_rcnn_iter_110000.ckpt',),
    'res50': ('res50_faster_rcnn_iter_159000.ckpt',),
}

# Training-set directory name (under output/<net>/) for each dataset key.
DATASETS = {
    'pascal_voc': ('voc_2007_trainval',),
    'pascal_voc_0712': ('voc_2007_trainval+voc_2012_trainval',),
    'grasp': ('train',),
}

def dotproduct(a, b):
    """Return the dot product of vectors ``a`` and ``b``.

    Iterates over the elements of ``a`` and pairs each with the element of
    ``b`` at the same index, so ``b`` must be at least as long as ``a``
    (an ``IndexError`` is raised otherwise).
    """
    return sum(x * b[i] for i, x in enumerate(a))


def veclength(a):
    """Return the Euclidean length (L2 norm) of vector ``a``."""
    return sum(x ** 2 for x in a) ** .5


def ange(a, b):
    """Return the angle between vectors ``a`` and ``b`` in degrees, folded to [0, 90].

    Angles above 90 degrees are mapped to ``180 - angle``, so the direction
    (sign) of either vector does not matter; only the orientation of the
    lines they span is compared.

    Returns NaN when either vector has zero length, because the angle is
    undefined in that case. NaN compares false against any threshold, so
    callers that test ``ange(...) < limit`` treat it as "no match".
    """
    norm_product = veclength(a) * veclength(b)
    if norm_product == 0:
        return float('nan')

    costheta = dotproduct(a, b) / norm_product
    # Floating-point rounding can leave the cosine a hair outside [-1, 1]
    # for (anti)parallel vectors, which makes acos() raise ValueError.
    # Explicit comparisons (rather than min/max) keep a NaN cosine as NaN.
    if costheta > 1.0:
        costheta = 1.0
    elif costheta < -1.0:
        costheta = -1.0

    degree = acos(costheta) / pi * 180
    if degree > 90:
        degree = 180 - degree
    return degree


def Rotate2D(pts, cnt, ang=np.pi/4):
    """Rotate 2-D points about a centre.

    Args:
        pts: array-like of shape (n, 2), one point per row.
        cnt: array-like of shape (2,), the centre of rotation.
        ang: rotation angle in radians (default pi/4). A positive angle
            maps (1, 0) towards (0, 1) relative to the centre.

    Returns:
        numpy.ndarray of shape (n, 2) holding the rotated points.
    """
    pts = np.asarray(pts)
    cnt = np.asarray(cnt)
    c, s = np.cos(ang), np.sin(ang)
    # Points are row vectors, so they are right-multiplied by the matrix.
    rotation = np.array([[c, s], [-s, c]])
    return np.dot(pts - cnt, rotation) + cnt

def _load_gt_polygons(gt_file):
    """Read ground-truth rectangles from ``gt_file`` as shapely polygons.

    The first two whitespace-separated fields of every non-blank line are
    taken as the x and y coordinates of one vertex; each run of four
    consecutive vertices forms one polygon.

    Raises:
        ValueError: if the number of vertices is not a multiple of four.
    """
    vertices = []
    with open(gt_file) as f:
        for line in f:
            fields = line.split()
            if not fields:
                continue  # tolerate blank / trailing empty lines
            vertices.append((float(fields[0]), float(fields[1])))

    if len(vertices) % 4:
        raise ValueError(
            '{}: expected 4 vertices per rectangle, found {} vertices'.format(
                gt_file, len(vertices)))
    return [Polygon(vertices[i:i + 4]) for i in range(0, len(vertices), 4)]


def _matches_any_gt(pred_polygon, gt_polygons, iou_thresh=0.25, angle_thresh=10):
    """Return True if ``pred_polygon`` matches at least one ground-truth polygon.

    A match requires intersection-over-union above ``iou_thresh`` and an
    angle below ``angle_thresh`` degrees between the first edge of the
    prediction and the first edge of the ground-truth polygon.
    """
    pred_x, pred_y = pred_polygon.exterior.xy
    pred_edge = np.array([pred_x[0] - pred_x[1], pred_y[0] - pred_y[1]])
    for gt_polygon in gt_polygons:
        union_area = pred_polygon.union(gt_polygon).area
        if union_area <= 0:
            continue  # degenerate geometry: IoU is undefined, not a match
        iou = pred_polygon.intersection(gt_polygon).area / union_area
        if iou > iou_thresh:
            gt_x, gt_y = gt_polygon.exterior.xy
            gt_edge = np.array([gt_x[0] - gt_x[1], gt_y[0] - gt_y[1]])
            if ange(pred_edge, gt_edge) < angle_thresh:
                return True
    return False


def vis_detections(gt_file, ax, image_name, im, class_name, dets, thresh=0.1):
    """Draw the rotated detection boxes and score them against ground truth.

    Args:
        gt_file: path to the ground-truth text file (four vertex lines per
            rectangle).
        ax: matplotlib axes that receives the image, the boxes and the title.
        image_name: unused; kept for interface compatibility.
        im: image array whose last axis is reversed before display (the
            caller loads it with cv2, i.e. BGR order).
        class_name: label of the form ``'angle_NN'``; ``NN`` selects the
            rotation applied to every box.
        dets: array with one detection per row, ``[x1, y1, x2, y2, score]``.
        thresh: minimum score for a detection to be drawn and evaluated.

    Returns:
        Tuple ``(overlap_found, num_predicted, num_matched)``: whether any
        kept detection matched ground truth, how many detections passed
        ``thresh``, and how many of those matched. When no detection passes
        ``thresh`` nothing is read or drawn and ``(False, 0, 0)`` is returned.
    """
    inds = np.where(dets[:, -1] >= thresh)[0]
    if len(inds) == 0:
        return False, 0, 0

    gt_polygons = _load_gt_polygons(gt_file)

    # Reverse the channel axis (BGR -> RGB) for matplotlib.
    ax.imshow(im[:, :, (2, 1, 0)], aspect='equal')

    # Class k rotates the axis-aligned box by -pi/2 - (k - 1) * pi/20.
    angle = int(class_name[6:])
    rotation = -pi/2 - pi/20*(angle - 1)

    # Alternate thin black / thick red edges so the box orientation is visible.
    edge_styles = (('k', 1), ('r', 3), ('k', 1), ('r', 3))

    predicted_and_matched_num = 0
    for i in inds:
        x1, y1, x2, y2 = dets[i, :4]
        corners = ar([[x1, y1], [x2, y1], [x2, y2], [x1, y2]])
        centre = ar([(x1 + x2)/2, (y1 + y2)/2])
        r_bbox = Rotate2D(corners, centre, rotation)
        pred_polygon = Polygon([(r_bbox[k, 0], r_bbox[k, 1]) for k in range(4)])

        # exterior.xy repeats the first vertex at the end, giving 4 edges.
        pred_x, pred_y = pred_polygon.exterior.xy
        for k, (color, width) in enumerate(edge_styles):
            ax.plot(pred_x[k:k + 2], pred_y[k:k + 2], color=color, alpha=0.7,
                    linewidth=width, solid_capstyle='round', zorder=2)

        if _matches_any_gt(pred_polygon, gt_polygons):
            predicted_and_matched_num += 1

    ax.set_title(('{} detections with '
                  'p({} | box) >= {:.1f}').format(class_name, class_name,
                                                  thresh),
                 fontsize=14)
    return predicted_and_matched_num > 0, len(inds), predicted_and_matched_num

def demo(sess, net, image_name):
    """Run detection on one image, score it against ground truth and save a figure.

    Only the single highest-scoring proposal, with its best non-background
    class, is evaluated.

    Args:
        sess: TensorFlow session holding the restored network weights.
        net: network object passed through to ``im_detect``.
        image_name: path of the image fed to the detector. The ground-truth
            file and the RGB image used for display are derived from it.

    Side effects:
        Updates the module-level counters ``total_found``, ``total_count``,
        ``total_predicted`` and ``total_predicted_and_matched``, prints the
        running recall/precision, and writes
        ``./data/demo/results_pure/results_pure_pred/<name>.png``.

    Raises:
        IOError: if either the input image or the RGB display image cannot
            be read.
    """
    global total_found
    global total_count
    global total_predicted
    global total_predicted_and_matched

    # Load the demo image
    im_file = image_name
    # NOTE(review): dropping the last 36 characters assumes a fixed
    # directory / file-name layout in the image list - confirm against the data.
    gt_file_base = image_name[:-36]
    im = cv2.imread(im_file)
    if im is None:
        raise IOError('Could not read image: {}'.format(im_file))
    image_name = os.path.splitext(os.path.basename(im_file))[0]
    gt_file = gt_file_base + '/' + image_name[:7] + 'cposCropped320.txt'
    print('GT file: ' + gt_file)

    # Detect all object classes and regress object bounds
    timer = Timer()
    timer.tic()
    scores, boxes = im_detect(sess, net, im)
    num_proposals = boxes.shape[0]  # recorded before narrowing to one proposal

    # The visualization is drawn on the RGB version of the image.
    im_rgb_file = gt_file_base + '/' + image_name[:7] + 'rCropped320.png'
    im = cv2.imread(im_rgb_file)
    if im is None:
        raise IOError('Could not read image: {}'.format(im_rgb_file))

    # Keep only the proposal with the highest class score.
    # NOTE(review): the slice 1:-1 skips the background column but also the
    # last class (angle_19) when ranking proposals, although that class can
    # still be selected by the argmax below - confirm whether 1: was intended.
    scores_max = scores[:, 1:-1].max(axis=1)
    best_proposal = np.argmax(scores_max)
    scores = scores[best_proposal:best_proposal + 1, :]
    boxes = boxes[best_proposal:best_proposal + 1, :]
    scores[0, 0] = -1  # make sure background won't be the max
    best_cls = np.argmax(scores)

    timer.toc()
    print(('Detection took {:.3f}s for '
           '{:d} object proposals').format(timer.total_time, num_proposals))

    fig, ax = plt.subplots(figsize=(12, 12))
    CONF_THRESH = 0.0000001
    NMS_THRESH = 0.7

    overall_overlap_found = False
    # Only the winning class is visualized; skip it if it is not a known label.
    if 1 <= best_cls < len(CLASSES):
        cls_boxes = boxes[:, 4*best_cls:4*(best_cls + 1)]
        cls_scores = scores[:, best_cls]
        dets = np.hstack((cls_boxes,
                          cls_scores[:, np.newaxis])).astype(np.float32)
        keep = nms(dets, NMS_THRESH)
        dets = dets[keep, :]
        overall_overlap_found, num_predicted, num_matched = vis_detections(
            gt_file, ax, image_name, im, CLASSES[best_cls], dets,
            thresh=CONF_THRESH)
        total_predicted = total_predicted + num_predicted
        total_predicted_and_matched = total_predicted_and_matched + num_matched

    total_count = total_count + 1
    if overall_overlap_found:
        total_found = total_found + 1
    print('recall = ' + str(total_found) + '/' + str(total_count) + ' (the rate to find one grasp in one image)')
    print('precision = ' + str(total_predicted_and_matched) + '/' + str(total_predicted) + ' (the rate per predicted grasp passed Jaccard)')
    ax.axis('off')
    fig.tight_layout()

    # Save the result, creating the output directory on first use.
    save_dir = './data/demo/results_pure/results_pure_pred/'
    if not os.path.isdir(save_dir):
        os.makedirs(save_dir)
    fig.savefig(save_dir + str(image_name) + '.png')
    # One figure is created per image; close it so they do not pile up in
    # memory over the whole test set.
    plt.close(fig)

def parse_args():
    """Parse command-line arguments.

    Returns:
        argparse.Namespace with ``demo_net`` (a key of ``NETS``, default
        ``'res101'``) and ``dataset`` (a key of ``DATASETS``, default
        ``'pascal_voc_0712'``).
    """
    # Build the help text from the actual choices so the two cannot drift apart.
    net_choices = sorted(NETS)
    dataset_choices = sorted(DATASETS)

    parser = argparse.ArgumentParser(description='Tensorflow Faster R-CNN demo')
    parser.add_argument('--net', dest='demo_net',
                        help='Network to use [{}]'.format(' '.join(net_choices)),
                        choices=net_choices, default='res101')
    parser.add_argument('--dataset', dest='dataset',
                        help='Trained dataset [{}]'.format(' '.join(dataset_choices)),
                        choices=dataset_choices, default='pascal_voc_0712')
    return parser.parse_args()

if __name__ == '__main__':
    # Script entry point: restore the selected network, then evaluate every
    # image named in the test list with demo().
    cfg.TEST.HAS_RPN = True  # Use RPN for proposals
    args = parse_args()

    # model path
    demonet = args.demo_net
    dataset = args.dataset
    tfmodel = os.path.join('output', demonet, DATASETS[dataset][0], 'default',
                           NETS[demonet][0])

    if not os.path.isfile(tfmodel + '.meta'):
        raise IOError(('{:s} not found.\nDid you download the proper networks from '
                       'our server and place them properly?').format(tfmodel + '.meta'))

    # set config
    tfconfig = tf.ConfigProto(allow_soft_placement=True)
    tfconfig.gpu_options.allow_growth = True

    # init session
    sess = tf.Session(config=tfconfig)
    # load network
    if demonet == 'vgg16':
        net = vgg16(batch_size=1)
    elif demonet == 'res101':
        net = resnetv1(batch_size=1, num_layers=101)
    elif demonet == 'res50':
        net = resnetv1(batch_size=1, num_layers=50)
    else:
        raise NotImplementedError
    # The class count follows CLASSES instead of repeating the literal 20.
    net.create_architecture(sess, "TEST", len(CLASSES),
                            tag='default', anchor_scales=[8, 16, 32])
    saver = tf.train.Saver()
    saver.restore(sess, tfmodel)

    print('Loaded network {:s}'.format(tfmodel))

    print('reading images..')
    # NOTE(review): machine-specific absolute path; adjust for the local setup.
    image_list_file = '/media/fujenchu/home3/fasterrcnn_grasp/rgd_multibbs_5_5_5_object_tf/data/ImageSets/testfull.txt'
    im_names = []
    with open(image_list_file) as f:
        for line in f:
            im_path = line.rstrip('\n')
            if im_path:  # skip blank lines instead of passing '' to demo()
                im_names.append(im_path)

    print('total images: ' + str(len(im_names)))

    for im_name in im_names:
        print('~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~')
        print('evaluating {}'.format(im_name))
        demo(sess, net, im_name)
